# Natural Language Toolkit: Interface to Weka Classsifiers
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

"""
Classifiers that make use of the external 'Weka' package.
"""

import time
import tempfile
import os
import subprocess
import re
import zipfile
from sys import stdin

from nltk.probability import DictionaryProbDist
from nltk.internals import java, config_java

from nltk.classify.api import ClassifierI

_weka_classpath = None
_weka_search = [
    ".",
    "/usr/share/weka",
    "/usr/local/share/weka",
    "/usr/lib/weka",
    "/usr/local/lib/weka",
]


def config_weka(classpath=None):
    global _weka_classpath

    # Make sure java's configured first.
    config_java()

    if classpath is not None:
        _weka_classpath = classpath

    if _weka_classpath is None:
        searchpath = _weka_search
        if "WEKAHOME" in os.environ:
            searchpath.insert(0, os.environ["WEKAHOME"])

        for path in searchpath:
            if os.path.exists(os.path.join(path, "weka.jar")):
                _weka_classpath = os.path.join(path, "weka.jar")
                version = _check_weka_version(_weka_classpath)
                if version:
                    print(
                        ("[Found Weka: %s (version %s)]" % (_weka_classpath, version))
                    )
                else:
                    print("[Found Weka: %s]" % _weka_classpath)
                _check_weka_version(_weka_classpath)

    if _weka_classpath is None:
        raise LookupError(
            "Unable to find weka.jar!  Use config_weka() "
            "or set the WEKAHOME environment variable. "
            "For more information about Weka, please see "
            "http://www.cs.waikato.ac.nz/ml/weka/"
        )


def _check_weka_version(jar):
    try:
        zf = zipfile.ZipFile(jar)
    except (SystemExit, KeyboardInterrupt):
        raise
    except:
        return None
    try:
        try:
            return zf.read("weka/core/version.txt")
        except KeyError:
            return None
    finally:
        zf.close()


class WekaClassifier(ClassifierI):
    def __init__(self, formatter, model_filename):
        self._formatter = formatter
        self._model = model_filename

    def prob_classify_many(self, featuresets):
        return self._classify_many(featuresets, ["-p", "0", "-distribution"])

    def classify_many(self, featuresets):
        return self._classify_many(featuresets, ["-p", "0"])

    def _classify_many(self, featuresets, options):
        # Make sure we can find java & weka.
        config_weka()

        temp_dir = tempfile.mkdtemp()
        try:
            # Write the test data file.
            test_filename = os.path.join(temp_dir, "test.arff")
            self._formatter.write(test_filename, featuresets)

            # Call weka to classify the data.
            cmd = [
                "weka.classifiers.bayes.NaiveBayes",
                "-l",
                self._model,
                "-T",
                test_filename,
            ] + options
            (stdout, stderr) = java(
                cmd,
                classpath=_weka_classpath,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            # Check if something went wrong:
            if stderr and not stdout:
                if "Illegal options: -distribution" in stderr:
                    raise ValueError(
                        "The installed version of weka does "
                        "not support probability distribution "
                        "output."
                    )
                else:
                    raise ValueError("Weka failed to generate output:\n%s" % stderr)

            # Parse weka's output.
            return self.parse_weka_output(stdout.decode(stdin.encoding).split("\n"))

        finally:
            for f in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, f))
            os.rmdir(temp_dir)

    def parse_weka_distribution(self, s):
        probs = [float(v) for v in re.split("[*,]+", s) if v.strip()]
        probs = dict(zip(self._formatter.labels(), probs))
        return DictionaryProbDist(probs)

    def parse_weka_output(self, lines):
        # Strip unwanted text from stdout
        for i, line in enumerate(lines):
            if line.strip().startswith("inst#"):
                lines = lines[i:]
                break

        if lines[0].split() == ["inst#", "actual", "predicted", "error", "prediction"]:
            return [line.split()[2].split(":")[1] for line in lines[1:] if line.strip()]
        elif lines[0].split() == [
            "inst#",
            "actual",
            "predicted",
            "error",
            "distribution",
        ]:
            return [
                self.parse_weka_distribution(line.split()[-1])
                for line in lines[1:]
                if line.strip()
            ]

        # is this safe:?
        elif re.match(r"^0 \w+ [01]\.[0-9]* \?\s*$", lines[0]):
            return [line.split()[1] for line in lines if line.strip()]

        else:
            for line in lines[:10]:
                print(line)
            raise ValueError(
                "Unhandled output format -- your version "
                "of weka may not be supported.\n"
                "  Header: %s" % lines[0]
            )

    # [xx] full list of classifiers (some may be abstract?):
    # ADTree, AODE, BayesNet, ComplementNaiveBayes, ConjunctiveRule,
    # DecisionStump, DecisionTable, HyperPipes, IB1, IBk, Id3, J48,
    # JRip, KStar, LBR, LeastMedSq, LinearRegression, LMT, Logistic,
    # LogisticBase, M5Base, MultilayerPerceptron,
    # MultipleClassifiersCombiner, NaiveBayes, NaiveBayesMultinomial,
    # NaiveBayesSimple, NBTree, NNge, OneR, PaceRegression, PART,
    # PreConstructedLinearModel, Prism, RandomForest,
    # RandomizableClassifier, RandomTree, RBFNetwork, REPTree, Ridor,
    # RuleNode, SimpleLinearRegression, SimpleLogistic,
    # SingleClassifierEnhancer, SMO, SMOreg, UserClassifier, VFI,
    # VotedPerceptron, Winnow, ZeroR

    _CLASSIFIER_CLASS = {
        "naivebayes": "weka.classifiers.bayes.NaiveBayes",
        "C4.5": "weka.classifiers.trees.J48",
        "log_regression": "weka.classifiers.functions.Logistic",
        "svm": "weka.classifiers.functions.SMO",
        "kstar": "weka.classifiers.lazy.KStar",
        "ripper": "weka.classifiers.rules.JRip",
    }

    @classmethod
    def train(
        cls,
        model_filename,
        featuresets,
        classifier="naivebayes",
        options=[],
        quiet=True,
    ):
        # Make sure we can find java & weka.
        config_weka()

        # Build an ARFF formatter.
        formatter = ARFF_Formatter.from_train(featuresets)

        temp_dir = tempfile.mkdtemp()
        try:
            # Write the training data file.
            train_filename = os.path.join(temp_dir, "train.arff")
            formatter.write(train_filename, featuresets)

            if classifier in cls._CLASSIFIER_CLASS:
                javaclass = cls._CLASSIFIER_CLASS[classifier]
            elif classifier in cls._CLASSIFIER_CLASS.values():
                javaclass = classifier
            else:
                raise ValueError("Unknown classifier %s" % classifier)

            # Train the weka model.
            cmd = [javaclass, "-d", model_filename, "-t", train_filename]
            cmd += list(options)
            if quiet:
                stdout = subprocess.PIPE
            else:
                stdout = None
            java(cmd, classpath=_weka_classpath, stdout=stdout)

            # Return the new classifier.
            return WekaClassifier(formatter, model_filename)

        finally:
            for f in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, f))
            os.rmdir(temp_dir)


class ARFF_Formatter:
    """
    Converts featuresets and labeled featuresets to ARFF-formatted
    strings, appropriate for input into Weka.

    Features and classes can be specified manually in the constructor, or may
    be determined from data using ``from_train``.
    """

    def __init__(self, labels, features):
        """
        :param labels: A list of all class labels that can be generated.
        :param features: A list of feature specifications, where
            each feature specification is a tuple (fname, ftype);
            and ftype is an ARFF type string such as NUMERIC or
            STRING.
        """
        self._labels = labels
        self._features = features

    def format(self, tokens):
        """Returns a string representation of ARFF output for the given data."""
        return self.header_section() + self.data_section(tokens)

    def labels(self):
        """Returns the list of classes."""
        return list(self._labels)

    def write(self, outfile, tokens):
        """Writes ARFF data to a file for the given data."""
        if not hasattr(outfile, "write"):
            outfile = open(outfile, "w")
        outfile.write(self.format(tokens))
        outfile.close()

    @staticmethod
    def from_train(tokens):
        """
        Constructs an ARFF_Formatter instance with class labels and feature
        types determined from the given data. Handles boolean, numeric and
        string (note: not nominal) types.
        """
        # Find the set of all attested labels.
        labels = set(label for (tok, label) in tokens)

        # Determine the types of all features.
        features = {}
        for tok, label in tokens:
            for (fname, fval) in tok.items():
                if issubclass(type(fval), bool):
                    ftype = "{True, False}"
                elif issubclass(type(fval), (int, float, bool)):
                    ftype = "NUMERIC"
                elif issubclass(type(fval), str):
                    ftype = "STRING"
                elif fval is None:
                    continue  # can't tell the type.
                else:
                    raise ValueError("Unsupported value type %r" % ftype)

                if features.get(fname, ftype) != ftype:
                    raise ValueError("Inconsistent type for %s" % fname)
                features[fname] = ftype
        features = sorted(features.items())

        return ARFF_Formatter(labels, features)

    def header_section(self):
        """Returns an ARFF header as a string."""
        # Header comment.
        s = (
            "% Weka ARFF file\n"
            + "% Generated automatically by NLTK\n"
            + "%% %s\n\n" % time.ctime()
        )

        # Relation name
        s += "@RELATION rel\n\n"

        # Input attribute specifications
        for fname, ftype in self._features:
            s += "@ATTRIBUTE %-30r %s\n" % (fname, ftype)

        # Label attribute specification
        s += "@ATTRIBUTE %-30r {%s}\n" % ("-label-", ",".join(self._labels))

        return s

    def data_section(self, tokens, labeled=None):
        """
        Returns the ARFF data section for the given data.

        :param tokens: a list of featuresets (dicts) or labelled featuresets
            which are tuples (featureset, label).
        :param labeled: Indicates whether the given tokens are labeled
            or not.  If None, then the tokens will be assumed to be
            labeled if the first token's value is a tuple or list.
        """
        # Check if the tokens are labeled or unlabeled.  If unlabeled,
        # then use 'None'
        if labeled is None:
            labeled = tokens and isinstance(tokens[0], (tuple, list))
        if not labeled:
            tokens = [(tok, None) for tok in tokens]

        # Data section
        s = "\n@DATA\n"
        for (tok, label) in tokens:
            for fname, ftype in self._features:
                s += "%s," % self._fmt_arff_val(tok.get(fname))
            s += "%s\n" % self._fmt_arff_val(label)

        return s

    def _fmt_arff_val(self, fval):
        if fval is None:
            return "?"
        elif isinstance(fval, (bool, int)):
            return "%s" % fval
        elif isinstance(fval, float):
            return "%r" % fval
        else:
            return "%r" % fval


if __name__ == "__main__":
    from nltk.classify.util import names_demo, binary_names_demo_features

    def make_classifier(featuresets):
        return WekaClassifier.train("/tmp/name.model", featuresets, "C4.5")

    classifier = names_demo(make_classifier, binary_names_demo_features)
